#include "Device.h"

#include <cstring>
#include <structmember.h>
#include <sstream>
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/utils/object_ptr.h"
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/python_strings.h"

PyObject *THPDevice_New(const torch::Device& device)
{
  auto type = (PyTypeObject*)&THPDeviceType;
  auto self = THPObjectPtr{type->tp_alloc(type, 0)};
  if (!self) throw python_error();
  auto self_ = reinterpret_cast<THPDevice*>(self.get());
  self_->device = device;
  return self.release();
}

static const char* cuda_str = "cuda";
static const char* cpu_str = "cpu";

static inline const char* deviceTypeString(torch::DeviceType device_type) {
  switch (device_type) {
    case torch::DeviceType::CUDA:
      return cuda_str;
    case torch::DeviceType::CPU:
      return cpu_str;
    default:
      throw std::runtime_error("unexpected device type");
  }
}

PyObject *THPDevice_repr(THPDevice *self)
{
  std::ostringstream oss;
  oss << "device(type=\'" << deviceTypeString(self->device.type) << "\'";
  if (!self->device.is_default) {
    oss << ", index=" << self->device.index;
  }
  oss << ")";
  return THPUtils_packString(oss.str().c_str());
}

PyObject *THPDevice_str(THPDevice*self)
{
  std::ostringstream oss;
  if (!self->device.is_default) {
    oss << deviceTypeString(self->device.type) << ":" << self->device.index;
  } else {
    oss << deviceTypeString(self->device.type);
  }
  return THPUtils_packString(oss.str().c_str());
}

PyObject *THPDevice_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  static torch::PythonArgParser parser({
    "Device(Device device)",
    "Device(std::string type, int64_t? index=-1)"
  });
  torch::ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    auto device = r.device(0);
    return THPDevice_New(device);
  } else if (r.idx == 1) {
    auto as_device = r.device(0);  // this works, because device can take strings
    auto device_type = r.string(0);
    if (!as_device.is_default) {
      throw std::runtime_error("type (string) must not include an index because index "
                                "was passed explicitly: " + device_type);
    }

    auto is_default = r.isNone(1);
    auto device_index = r.toInt64WithDefault(1, -1);
    // make sure this is constructible
    auto device = torch::Device(as_device.type, device_index, is_default);
    return THPDevice_New(device);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject *THPDevice_type(THPDevice *self)
{
  HANDLE_TH_ERRORS
  return THPUtils_packString(deviceTypeString(self->device.type));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject *THPDevice_index(THPDevice *self)
{
  HANDLE_TH_ERRORS
  if (self->device.is_default) {
    Py_RETURN_NONE;
  } else {
    return THPUtils_packInt64(self->device.index);
  }
  END_HANDLE_TH_ERRORS
}

PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) {
  HANDLE_TH_ERRORS
  if (!THPDevice_Check(a) || !THPDevice_Check(b)) {
    // Py_RETURN_NOTIMPLEMENTED not in python 2.
    Py_INCREF(Py_NotImplemented);
    return Py_NotImplemented;
  }
  THPDevice *da = reinterpret_cast<THPDevice*>(a);
  THPDevice *db = reinterpret_cast<THPDevice*>(b);

  switch(op) {
    case Py_EQ:
      if (da->device == db->device) {
        Py_RETURN_TRUE;
      } else {
        Py_RETURN_FALSE;
      }
    case Py_NE:
      if (da->device == db->device) {
        Py_RETURN_FALSE;
      } else {
        Py_RETURN_TRUE;
      }
    case Py_LT:
    case Py_LE:
    case Py_GT:
    case Py_GE:
      throw torch::TypeError("comparison not implemented");
    default:
      throw torch::TypeError("unexpected comparison op");
  }
  END_HANDLE_TH_ERRORS
}

typedef PyObject *(*getter)(PyObject *, void *);

static struct PyGetSetDef THPDevice_properties[] = {
  {"type",       (getter)THPDevice_type, nullptr, nullptr, nullptr},
  {"index",      (getter)THPDevice_index, nullptr, nullptr, nullptr},
  {nullptr}
};

PyTypeObject THPDeviceType = {
  PyVarObject_HEAD_INIT(nullptr, 0)
  "torch.Device",                        /* tp_name */
  sizeof(THPDevice),                     /* tp_basicsize */
  0,                                     /* tp_itemsize */
  0,                                     /* tp_dealloc */
  0,                                     /* tp_print */
  0,                                     /* tp_getattr */
  0,                                     /* tp_setattr */
  0,                                     /* tp_reserved */
  (reprfunc)THPDevice_repr,              /* tp_repr */
  0,                                     /* tp_as_number */
  0,                                     /* tp_as_sequence */
  0,                                     /* tp_as_mapping */
  0,                                     /* tp_hash  */
  0,                                     /* tp_call */
  (reprfunc)THPDevice_str,               /* tp_str */
  0,                                     /* tp_getattro */
  0,                                     /* tp_setattro */
  0,                                     /* tp_as_buffer */
  Py_TPFLAGS_DEFAULT,                    /* tp_flags */
  nullptr,                               /* tp_doc */
  0,                                     /* tp_traverse */
  0,                                     /* tp_clear */
  (richcmpfunc)THPDevice_rc,             /* tp_richcompare */
  0,                                     /* tp_weaklistoffset */
  0,                                     /* tp_iter */
  0,                                     /* tp_iternext */
  0,                                     /* tp_methods */
  0,                                     /* tp_members */
  THPDevice_properties,                  /* tp_getset */
  0,                                     /* tp_base */
  0,                                     /* tp_dict */
  0,                                     /* tp_descr_get */
  0,                                     /* tp_descr_set */
  0,                                     /* tp_dictoffset */
  0,                                     /* tp_init */
  0,                                     /* tp_alloc */
  THPDevice_pynew,                       /* tp_new */
};

void THPDevice_init(PyObject *module)
{
  if (PyType_Ready(&THPDeviceType) < 0) {
    throw python_error();
  }
  Py_INCREF(&THPDeviceType);
  if (PyModule_AddObject(module, "device", (PyObject *)&THPDeviceType) != 0) {
    throw python_error();
  }
}
